import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
# Change Matplotlib style.
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names, so use whichever spelling this
# installation provides.
if 'seaborn-v0_8-whitegrid' in plt.style.available:
    plt.style.use('seaborn-v0_8-whitegrid')
else:
    plt.style.use('seaborn-whitegrid')
# Static images of the plots are embedded in the notebook
%matplotlib inline
Basic Plot Example
# Sample 100 evenly spaced x values on [0, 10] and draw sine (solid line)
# and cosine (dashed line) on one figure. `fig` is kept because the next
# cell saves it to disk.
x = np.linspace(start=0, stop=10, num=100) # evenly spaced numbers
fig = plt.figure()
plt.plot(x, np.sin(x), '-')
plt.plot(x, np.cos(x), '--');
Saving Figures to File
# Save Image: write the figure created above to a PNG file
fig.savefig('my_figure.png')
# Load Image: read the file back and display it to check what was written
from IPython.display import Image
Image('my_figure.png')
# List the file types supported by this figure's canvas
fig.canvas.get_supported_filetypes()
# MATLAB-style interface: each plt.* call acts on the current figure/axes
plt.figure() # Create a plot figure
# Create the first of two panels and set current axis
plt.subplot(2, 1, 1) # (rows, columns, panel number)
plt.plot(x, np.sin(x))
# Create the second panel and set current axis
plt.subplot(2, 1, 2)
plt.plot(x, np.cos(x));
# This interface is 'stateful': it keeps track of the current figure and axes
# plt.gcf(): get current figure
# plt.gca(): get current axes
# Object-oriented interface: plotting functions are methods of explicit
# Figure and Axes objects instead of acting on a "current" axes.
# First create a grid of plots
# ax will be an array of two Axes objects
fig, ax = plt.subplots(2)
# Call plot() method on the appropriate object
ax[0].plot(x, np.sin(x))
ax[1].plot(x, np.cos(x));
# Simplest case: one figure holding one axes with a single sine curve.
# x is re-sampled here with 1000 points and reused by the cells below.
fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 10, 1000)
ax.plot(x, np.sin(x));
# Over-plotting multiple lines: call plot() once per line
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x));
Adjusting the Plot: Line Colors and Styles
# Six shifted sine curves, each colored using a different way of specifying
# a color. The last call used to be fused into the previous line's comment
# and therefore never ran.
plt.plot(x, np.sin(x - 0), color='blue')         # specify color by name
plt.plot(x, np.sin(x - 1), color='g')            # short color code (rgbcmyk)
plt.plot(x, np.sin(x - 2), color='0.75')         # grayscale between 0 and 1
plt.plot(x, np.sin(x - 3), color='#FFDD44')      # hex code (RRGGBB from 00 to FF)
plt.plot(x, np.sin(x - 4), color=(1.0,0.2,0.3))  # RGB tuple, values 0 to 1
plt.plot(x, np.sin(x - 5), color='chartreuse');  # all HTML color names supported
# Line styles by full name; each line is offset by one unit so they do
# not overlap
plt.plot(x, x + 0, linestyle='solid')
plt.plot(x, x + 1, linestyle='dashed')
plt.plot(x, x + 2, linestyle='dashdot')
plt.plot(x, x + 3, linestyle='dotted');
# For short, you can use the following codes:
plt.plot(x, x + 4, linestyle='-') # solid
plt.plot(x, x + 5, linestyle='--') # dashed
plt.plot(x, x + 6, linestyle='-.') # dashdot
plt.plot(x, x + 7, linestyle=':'); # dotted
Adjusting the Plot: Axes Limits
# Example of setting axis limits explicitly
plt.plot(x, np.sin(x))
plt.xlim(-1, 11)
plt.ylim(-1.5, 1.5);
# Example of reversing an axis
plt.plot(x, np.sin(x))
# Passing the limits in reverse order (larger value first) flips the axis
plt.xlim(10, 0)
plt.ylim(1.2, -1.2);
# Set the limits of both axes in a single call: [xmin, xmax, ymin, ymax]
plt.plot(x, np.sin(x))
plt.axis([-1, 11, -1.5, 1.5]);
# Automatically tighten the axis bounds around the plotted data
plt.plot(x, np.sin(x))
plt.axis('tight')
# Example of an “equal” layout, with units matched to the output resolution
# Equal aspect ratio: one unit in x is drawn as long as one unit in y
plt.plot(x, np.sin(x))
plt.axis('equal');
Labeling Plots
# Examples of axis labels and title
plt.plot(x, np.sin(x))
plt.title("A Sine Curve")
plt.xlabel("x")
plt.ylabel("sin(x)");
# Plot legend example: the label= given to each plot() call is the text
# that plt.legend() shows for that line
plt.plot(x, np.sin(x), '-g', label='sin(x)')
plt.plot(x, np.cos(x), ':b', label='cos(x)')
plt.axis('equal')
plt.legend();
# Mapping between MATLAB and object-oriented style
# plt.xlabel() → ax.set_xlabel()
# plt.ylabel() → ax.set_ylabel()
# plt.xlim() → ax.set_xlim()
# plt.ylim() → ax.set_ylim()
# plt.title() → ax.set_title()
# Object-oriented interface to plotting: ax.set() applies several of the
# properties listed above in one call
ax = plt.axes()
ax.plot(x, np.sin(x))
ax.set(xlim=(0, 10), ylim=(-2, 2),
xlabel='x', ylabel='sin(x)',
title='A Simple Plot');
# Scatter plot example: plt.plot with the 'o' format code draws a marker
# at each point. x and y are redefined here (30 samples) for the scatter
# cells that follow.
x = np.linspace(0, 10, 30)
y = np.sin(x)
plt.plot(x, y, 'o', color='black');
# Demonstration of point markers: one random five-point series per marker
# code, labeled with the code so the legend acts as a key.
rng = np.random.RandomState(0)  # fixed seed so the figure is reproducible
for marker in ['o', '.', ',', 'x', '+', 'v', '^', '<', '>', 's', 'd']:
    plt.plot(rng.rand(5), rng.rand(5), marker,
             label="marker='{0}'".format(marker))
plt.legend(numpoints=1)
# The data lie in [0, 1); the wider x range leaves room for the legend
plt.xlim(0, 1.8);
# Combining line and point markers in one format string
plt.plot(x, y, '-ok'); # line (-), circle marker (o), black (k)
# Customizing line and marker properties through keyword arguments
plt.plot(x, y, '-p', color='gray',
markersize=15, linewidth=4,
markerfacecolor='white',
markeredgecolor='gray',
markeredgewidth=2)
plt.ylim(-1.2, 1.2);
A second, more powerful method of creating scatter plots is the plt.scatter() function.
# Same data as above, drawn with plt.scatter
plt.scatter(x, y, marker='o');
# Changing size, color, and transparency in scatter points:
# c= and s= take one value per point
rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)
colors = rng.rand(100)
sizes = 1000 * rng.rand(100)
plt.scatter(x, y, c=colors, s=sizes, alpha=0.3,
cmap='viridis')
plt.colorbar(); # show color scale
# Using point properties to encode features of the Iris data
from sklearn.datasets import load_iris
iris = load_iris()
# Transpose so that features[k] is the k-th feature across all samples
features = iris.data.T
# x/y position: features 0 and 1; marker size: feature 3; color: class label
plt.scatter(features[0], features[1], alpha=0.2,
s=100*features[3], c=iris.target, cmap='viridis')
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1]);
# plt.plot should be preferred over plt.scatter for large datasets:
# plt.scatter determines the appearance of each point individually.
# An errorbar example: dy is used both as the scale of the simulated
# noise and as the size of the error bars
x = np.linspace(0, 10, 50)
dy = 0.8
y = np.sin(x) + dy * np.random.randn(50)
plt.errorbar(x, y, yerr=dy, fmt='.k');
# Customizing the appearance of the errorbars
plt.errorbar(x, y, yerr=dy, fmt='o', color='black',
ecolor='lightgray', elinewidth=3, capsize=0);
# I often find it helpful, especially in crowded plots, to make
# the errorbars lighter than the points themselves
# The legacy GaussianProcess estimator is no longer part of scikit-learn
# (and was never importable from the top-level `sklearn` package);
# GaussianProcessRegressor is its replacement.
from sklearn.gaussian_process import GaussianProcessRegressor


# define the model and draw some data
def model(x):
    """Return x * sin(x), the noise-free function being fitted."""
    return x * np.sin(x)


xdata = np.array([1, 3, 5, 6, 8])
ydata = model(xdata)

# Compute the Gaussian process fit (estimator defaults).
# scikit-learn expects a 2-D sample array, hence the added axis.
gp = GaussianProcessRegressor()
gp.fit(xdata[:, np.newaxis], ydata)

xfit = np.linspace(0, 10, 1000)
# return_std=True returns the predictive standard deviation directly
yfit, sigma = gp.predict(xfit[:, np.newaxis], return_std=True)
MSE = sigma ** 2  # predictive variance, under the name this cell used before
dyfit = 2 * sigma  # 2*sigma ~ 95% confidence region

# Visualize the result
plt.plot(xdata, ydata, 'or')
# Representing continuous uncertainty with filled regions
plt.plot(xfit, yfit, '-', color='gray')
plt.fill_between(xfit, yfit - dyfit, yfit + dyfit,
                 color='gray', alpha=0.2)
plt.xlim(0, 10);
# Sometimes it is useful to display three-dimensional data in
# two dimensions using contours or color-coded regions
# plt.contour for contour plots
# plt.contourf for filled contour plots
# plt.imshow for showing images
# Contour plot using a function z = f(x,y)
def f(x, y):
    """Return sin(x)**10 + cos(10 + y*x) * cos(x).

    Built only from NumPy ufuncs and arithmetic, so x and y may be scalars
    or arrays of matching/broadcastable shape.
    """
    return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)
# The x and y values represent positions on the plot, and
# the z values will be represented by the contour levels
x = np.linspace(0, 5, 50)
y = np.linspace(0, 5, 40)
# The most straightforward way to prepare such data is to use the
# np.meshgrid function, which builds two-dimensional grids from
# one-dimensional arrays
# (X, Y and Z have shape (40, 50): rows follow y, columns follow x)
X, Y = np.meshgrid(x, y)
Z=f(X,Y)
# Plot black contour lines and label each line with its level
contour = plt.contour(X, Y, Z, colors='black');
plt.clabel(contour, inline=True, fontsize=8)
# By default when a single color is used, negative values are
# represented by dashed lines, and positive values by solid lines.
# Visualizing three-dimensional data with colored contours
# Color-code the lines by specifying a colormap with the cmap argument,
# and ask for more lines to be drawn: the 30 passed here is the requested
# number of contour levels within the data range
plt.contour(X, Y, Z, 30, cmap='RdGy');
# Visualizing three-dimensional data with filled contours
plt.contourf(X, Y, Z, 20, cmap='RdGy')
plt.colorbar(); # adds a labeled color scale for the plot
# The colorbar makes it clear that the black regions are “peaks,”
# while the red regions are “valleys.”
# Representing three-dimensional data as an image
# One could set the number of contours to a very high number, but this
# results in a rather inefficient plot: Matplotlib must render a new
# polygon for each step in the level.
# A better way to handle this is to use the plt.imshow() function,
# which interprets a two-dimensional grid of data as an image.
plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower', cmap='RdGy')
plt.colorbar()
# 'image' is a positional option of plt.axis(); the keyword spelling
# plt.axis(aspect='image') is rejected by current Matplotlib.
plt.axis('image');
# plt.imshow() doesn’t accept an x and y grid, so you must manually
# specify the extent [xmin, xmax, ymin, ymax] of the image on the plot.
# plt.imshow() by default follows the standard image array definition
# where the origin is in the upper left, not in the lower left as in
# most contour plots. This must be changed when showing gridded data.
# plt.imshow() will automatically adjust the axis aspect ratio to match
# the input data; you can change this by setting, for example,
# plt.axis('image') to make x and y units match.
# Labeled contours on top of an image:
# a partially transparent background image (alpha=0.5) is over-plotted
# with contours that carry their labels on the lines themselves
contours = plt.contour(X, Y, Z, 3, colors='black')
plt.clabel(contours, inline=True, fontsize=8)
plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower',
cmap='RdGy', alpha=0.5)
plt.colorbar();
# A simple histogram of 1000 standard-normal samples (default binning)
data = np.random.randn(1000)
plt.hist(data);
# More advanced histogram
# Density histogram: density=True normalizes the bars so that the
# histogram integrates to 1 instead of showing raw counts
plt.hist(data, bins=30, density=True, alpha=0.5,
histtype='stepfilled', color='steelblue',
edgecolor='none');
# Over-plotting multiple histograms
# histtype='stepfilled' along with some transparency (alpha) is
# useful when comparing histograms of several distributions
x1 = np.random.normal(0, 0.8, 1000)
x2 = np.random.normal(-2, 1, 1000)
x3 = np.random.normal(3, 2, 1000)
# density=True replaces the removed normed=True keyword and matches the
# density histogram drawn in the previous cell
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=40)
plt.hist(x1, **kwargs)
plt.hist(x2, **kwargs)
plt.hist(x3, **kwargs);
# Count the number of points in each bin without plotting anything
counts, bin_edges = np.histogram(data, bins=5)
print(counts)
# We can create histograms in two dimensions by
# dividing points among two-dimensional bins
mean = [0, 0]
cov = [[1, 1], [1, 2]]
# The draw has shape (10000, 2); transposing lets it unpack into x and y
x, y = np.random.multivariate_normal(mean, cov, 10000).T
# A two-dimensional histogram with plt.hist2d
plt.hist2d(x, y, bins=30, cmap='Blues')
cb = plt.colorbar()
cb.set_label('counts in bin')
# A two-dimensional histogram with plt.hexbin
# Another natural shape for such a tessellation is the regular hexagon.
plt.hexbin(x, y, gridsize=30, cmap='Blues')
cb = plt.colorbar(label='count in bin')
# A kernel density representation of a distribution
# Another common method of evaluating densities
# in multiple dimensions is kernel density estimation (KDE).
from scipy.stats import gaussian_kde
# fit an array of size [Ndim, Nsamples]
data = np.vstack([x, y])
kde = gaussian_kde(data)
# evaluate on a regular grid
xgrid = np.linspace(-3.5, 3.5, 40)
ygrid = np.linspace(-6, 6, 40)
Xgrid, Ygrid = np.meshgrid(xgrid, ygrid)
# The grid is flattened to the same [Ndim, Npoints] layout for evaluation
Z = kde.evaluate(np.vstack([Xgrid.ravel(), Ygrid.ravel()]))
# Plot the result as an image (Z is flat, so reshape it back to the grid)
plt.imshow(Z.reshape(Xgrid.shape),
origin='lower', aspect='auto',
extent=[-3.5, 3.5, -6, 6],
cmap='Blues')
cb = plt.colorbar()
cb.set_label("density")
# A default plot legend, built from the label= of each plotted line
x = np.linspace(0, 10, 1000)
fig, ax = plt.subplots()
ax.plot(x, np.sin(x), '-b', label='Sine')
ax.plot(x, np.cos(x), '--r', label='Cosine')
ax.axis('equal')
leg = ax.legend();
# A customized plot legend: choose the location and turn off the frame.
# The bare `fig` expression re-displays the updated figure in the notebook.
ax.legend(loc='upper left', frameon=False)
fig
# Specify the number of columns in the legend
# A two-column plot legend
ax.legend(frameon=False, loc='lower center', ncol=2)
fig
# Use a rounded box (fancybox) or add a shadow, change the transparency
# (alpha value) of the frame, or change the padding around the text
ax.legend(fancybox=True, framealpha=1, shadow=True, borderpad=1)
fig
# Customization of legend elements
# Fine-tune which elements and labels appear in the
# legend by using the objects returned by plot commands
# y has one column per phase shift (0, pi/2, pi, 3*pi/2), i.e. four curves
y = np.sin(x[:, np.newaxis] + np.pi * np.arange(0, 2, 0.5))
lines = plt.plot(x, y)
# lines is a list of plt.Line2D instances
plt.legend(lines[:2], ['first', 'second']);
# Alternative method
# The legend ignores all elements without a label attribute set
plt.plot(x, y[:, 0], label='first')
plt.plot(x, y[:, 1], label='second')
plt.plot(x, y[:, 2:])
plt.legend(framealpha=1, frameon=True);
# Location, geographic size, and population of California cities
cities = pd.read_csv('data/california_cities.csv')

# Extract the data we're interested in
lat, lon = cities['latd'], cities['longd']
population, area = cities['population_total'], cities['area_total_km2']

# Scatter the points, using size and color but no label
plt.scatter(lon, lat, label=None,
            c=np.log10(population), cmap='viridis',
            s=area, linewidth=0, alpha=0.5)
# 'equal' is a positional option of plt.axis(); the keyword spelling
# plt.axis(aspect='equal') is rejected by current Matplotlib.
plt.axis('equal')
plt.xlabel('longitude')
plt.ylabel('latitude')
plt.colorbar(label='log$_{10}$(population)')
plt.clim(3, 7)

# Here we create a legend:
# we'll plot empty lists with the desired size and label.
# A separate loop variable keeps the `area` Series intact; the loop header
# used to be fused into the comment above, so the whole Series was passed
# as the marker size and label of a single legend entry.
for legend_area in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.3,
                s=legend_area, label=str(legend_area) + ' km$^2$')
plt.legend(scatterpoints=1, frameon=False,
           labelspacing=1, title='City Area')

plt.title('California Cities: Area and Population');
If you try to create a second legend using plt.legend() or ax.legend(), it will simply override the first one. We can work around this by creating a new legend artist from scratch, and then using the lower-level ax.add_artist() method to manually add the second artist to the plot.
fig, ax = plt.subplots()

# Four black sine curves, each shifted by a further pi/2 and drawn with
# its own line style; the Line2D handles are collected for the legends.
lines = []
styles = ['-', '--', '-.', ':']
x = np.linspace(0, 10, 1000)

for i, style in enumerate(styles):
    lines += ax.plot(x, np.sin(x - i * np.pi / 2),
                     style, color='black')
ax.axis('equal')

# specify the lines and labels of the first legend
ax.legend(lines[:2], ['line A', 'line B'],
          loc='upper right', frameon=False)

# Create the second legend and add the artist manually; a second
# ax.legend() call would replace the first legend instead of adding one.
from matplotlib.legend import Legend
leg = Legend(ax, lines[2:], ['line C', 'line D'],
             loc='lower right', frameon=False)
ax.add_artist(leg);
# The simplest colorbar can be created with the plt.colorbar function
x = np.linspace(0, 10, 1000)
# I is a 1000 x 1000 image: broadcasting a row of sin(x) against a
# column of cos(x)
I = np.sin(x) * np.cos(x[:, np.newaxis])
plt.imshow(I)
plt.colorbar();
# The colormap is chosen with the cmap argument of the plotting function
plt.imshow(I, cmap='gray');
Sequential colormaps: These consist of one continuous sequence of colors (e.g., binary or viridis);
Divergent colormaps: These usually contain two distinct colors, which show positive and negative deviations from a mean (e.g., RdBu or PuOr);
Qualitative colormaps: These mix colors with no particular sequence (e.g., rainbow or jet);
from matplotlib.colors import LinearSegmentedColormap
def grayscale_cmap(cmap):
    """Return a grayscale version of the given colormap.

    ``cmap`` is anything plt.get_cmap accepts (a colormap name or a
    Colormap instance). The result has the same number of entries, with
    each color's RGB channels replaced by its perceived luminance and the
    alpha channel left as it was.
    """
    # plt.get_cmap is used because the plt.cm.get_cmap alias was removed
    # in Matplotlib 3.9.
    cmap = plt.get_cmap(cmap)
    colors = cmap(np.arange(cmap.N))

    # convert RGBA to perceived grayscale luminance
    # cf. http://alienryderflex.com/hsp.html
    RGB_weight = [0.299, 0.587, 0.114]
    luminance = np.sqrt(np.dot(colors[:, :3] ** 2, RGB_weight))
    colors[:, :3] = luminance[:, np.newaxis]

    return LinearSegmentedColormap.from_list(cmap.name + "_gray", colors, cmap.N)
def view_colormap(cmap):
    """Plot a colormap with its grayscale equivalent.

    Draws two tick-less strips in a new figure: the colormap itself on
    top and the output of grayscale_cmap() for it underneath.
    """
    # plt.get_cmap is used because the plt.cm.get_cmap alias was removed
    # in Matplotlib 3.9.
    cmap = plt.get_cmap(cmap)
    colors = cmap(np.arange(cmap.N))

    cmap = grayscale_cmap(cmap)
    grayscale = cmap(np.arange(cmap.N))

    fig, ax = plt.subplots(2, figsize=(6, 2),
                           subplot_kw=dict(xticks=[], yticks=[]))
    # Wrapping the color list in another list makes it a one-row image
    ax[0].imshow([colors], extent=[0, 10, 0, 1])
    ax[1].imshow([grayscale], extent=[0, 10, 0, 1])
# The jet colormap and its uneven luminance scale
view_colormap('jet')
# The cubehelix colormap and its luminance
view_colormap('cubehelix')
# The RdBu colormap and its luminance
view_colormap('RdBu')
# Specifying colormap extensions
# We can narrow the color limits and indicate the out-of-bounds
# values with a triangular arrow at the top and bottom by setting
# the extend property.
# Make noise in 1% of the image pixels (I is modified in place)
speckles = (np.random.random(I.shape) < 0.01)
I[speckles] = np.random.normal(0, 3, np.count_nonzero(speckles))
plt.figure(figsize=(10, 3.5))
# Left panel: default color limits
plt.subplot(1, 2, 1) # row, col, id
plt.imshow(I, cmap='RdBu')
plt.colorbar()
# Right panel: limits clipped to [-1, 1], with arrows on the colorbar
# marking values beyond them
plt.subplot(1, 2, 2) # row, col, id
plt.imshow(I, cmap='RdBu')
plt.colorbar(extend='both')
plt.clim(-1, 1);
# Notice that in the left panel, the default color limits respond
# to the noisy pixels, and the range of the noise completely washes
# out the pattern we are interested in.
Colormaps are continuous by default, but sometimes you’d like to represent discrete values. The easiest way to do this is to use the plt.cm.get_cmap() function, passing the name of a suitable colormap along with the number of desired bins.
# Discrete colormap: 'Blues' resampled to 6 bins. plt.get_cmap is used
# because the plt.cm.get_cmap alias was removed in Matplotlib 3.9.
plt.imshow(I, cmap=plt.get_cmap('Blues', 6))
plt.colorbar()
plt.clim(-1, 1);
# Sample of handwritten digit data
# load images of the digits 0 through 5 and visualize several of them
from sklearn.datasets import load_digits
digits = load_digits(n_class=6)

fig, ax = plt.subplots(8, 8, figsize=(6, 6))  # 8x8 grid
# ax.flat walks the 64 axes in row-major order; axes i shows image i
for i, axi in enumerate(ax.flat):
    axi.imshow(digits.images[i], cmap='binary')
    axi.set(xticks=[], yticks=[])
# Manifold embedding of handwritten digit pixels
# project the digits into 2 dimensions using IsoMap
from sklearn.manifold import Isomap
iso = Isomap(n_components=2)
projection = iso.fit_transform(digits.data)

# Use a discrete colormap to view the results, setting the ticks
# and clim to improve the aesthetics of the resulting colorbar.
# plt.get_cmap is used because the plt.cm.get_cmap alias was removed in
# Matplotlib 3.9.
plt.scatter(projection[:, 0], projection[:, 1], lw=0.1,
            c=digits.target, cmap=plt.get_cmap('cubehelix', 6))
plt.colorbar(ticks=range(6), label='digit value')
# Half-integer limits center each of the six color bins on its digit tick
plt.clim(-0.5, 5.5)
Four routines for creating subplots in Matplotlib
# Example of an inset axes: the list is [left, bottom, width, height]
# in figure coordinates (fractions of the figure size)
ax1 = plt.axes() # standard axes
ax2 = plt.axes([0.65, 0.65, 0.2, 0.2])
# Vertically stacked axes example: the top panel sits directly on the
# bottom one (bottom 0.5 = 0.1 + 0.4), and its x tick labels are hidden
fig = plt.figure()
ax1 = fig.add_axes([0.1, 0.5, 0.8, 0.4],
xticklabels=[], ylim=(-1.2, 1.2))
ax2 = fig.add_axes([0.1, 0.1, 0.8, 0.4],
ylim=(-1.2, 1.2))
x = np.linspace(0, 10)
ax1.plot(np.sin(x))
ax2.plot(np.cos(x));
# A plt.subplot() example: panel numbers run from 1 (upper left) to 6
for i in range(1, 7):
    plt.subplot(2, 3, i)  # rows, cols, id
    plt.text(0.5, 0.5, str((2, 3, i)),
             fontsize=18, ha='center')
# Adjust the spacing between the subplots: hspace/wspace are given as a
# fraction of the subplot size
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(1, 7):
    ax = fig.add_subplot(2, 3, i)
    ax.text(0.5, 0.5, str((2, 3, i)),
            fontsize=18, ha='center')
# Shared x and y axis in plt.subplots()
# The optional keywords sharex and sharey allow you to
# specify the relationships between different axes
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
# axes are in a two-dimensional array, indexed by [row, col]
for i in range(2):
    for j in range(3):
        ax[i, j].text(0.5, 0.5, str((i, j)),
                      fontsize=18, ha='center')
fig
# In comparison to plt.subplot(), plt.subplots() is more
# consistent with Python’s conventional 0-based indexing.
To go beyond a regular grid to subplots that span multiple rows and columns, plt.GridSpec() is the best tool. It is simply a convenient interface that is recognized by the plt.subplot() command.
# Irregular subplots with plt.GridSpec: slicing the 2x3 grid selects the
# cells that each subplot spans
grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)
plt.subplot(grid[0, 0])
plt.subplot(grid[0, 1:])
plt.subplot(grid[1, :2])
plt.subplot(grid[1, 2]);
# Visualizing multidimensional distributions with plt.GridSpec
# The Seaborn package has its own plotting API for building a plot like this.
# Create some normally distributed data.
# `mean` is defined here explicitly; it used to be fused into the comment
# above, so this cell only ran if an earlier cell had left `mean` behind.
mean = [0, 0]
cov = [[1, 1], [1, 2]]
x, y = np.random.multivariate_normal(mean, cov, 3000).T
# Set up the axes with gridspec: a 4x4 grid in which the main scatter
# takes the upper-right 3x3 cells, the y histogram the left column and
# the x histogram the bottom row; the histograms share their data axis
# with the main axes
fig = plt.figure(figsize=(6, 6))
grid = plt.GridSpec(4, 4, hspace=0.2, wspace=0.2)
main_ax = fig.add_subplot(grid[:-1, 1:])
y_hist = fig.add_subplot(grid[:-1, 0], xticklabels=[], sharey=main_ax)
x_hist = fig.add_subplot(grid[-1, 1:], yticklabels=[], sharex=main_ax)
# scatter points on the main axes
main_ax.plot(x, y, 'ok', markersize=3, alpha=0.2)
# histogram on the attached axes; the count axis of each histogram is
# inverted so that the bars grow away from the main axes
x_hist.hist(x, 40, histtype='stepfilled',
orientation='vertical', color='gray')
x_hist.invert_yaxis()
y_hist.hist(y, 40, histtype='stepfilled',
orientation='horizontal', color='gray')
y_hist.invert_xaxis()
Creating a good visualization involves guiding the reader so that the figure tells a story. In some cases, this story can be told in an entirely visual manner, without the need for added text, but in others, small textual cues and labels are necessary. Perhaps the most basic types of annotations you will use are axes labels and titles, but the options go beyond this.
# Average daily births by date
births = pd.read_csv('data/births.csv')

# Sigma-clip outliers: 0.74 * IQR is a robust estimate of the standard
# deviation of a Gaussian, and rows more than 5 such sigmas from the
# median are dropped.
quartiles = np.percentile(births['births'], [25, 50, 75])
mu, sig = quartiles[1], 0.74 * (quartiles[2] - quartiles[0])
# .copy() makes the filtered frame independent, so the column assignment
# below does not write into a view of the original DataFrame
births = births.query('(births > @mu - 5 * @sig) & (births < @mu + 5 * @sig)').copy()

births['day'] = births['day'].astype(int)

# Build a datetime index from the year/month/day columns (YYYYMMDD integers)
births.index = pd.to_datetime(10000 * births.year +
                              100 * births.month +
                              births.day, format='%Y%m%d')
# Mean births for each (month, day) pair across all years
births_by_date = births.pivot_table(values='births',
                                    index=[births.index.month, births.index.day])
# Re-index on dates in one dummy year so the result plots as a time series.
# pd.Timestamp replaces pd.datetime, which was removed from pandas; 2012 is
# a leap year, so a February 29 group (if present) is still a valid date.
births_by_date.index = [pd.Timestamp(2012, month, day)
                        for (month, day) in births_by_date.index]

fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax);
# Add labels to the plot: ax.text(x, y, string) places text at the given
# data coordinates; ha sets the horizontal alignment relative to x
style = dict(size=10, color='gray')
ax.text('2012-1-1', 3950, "New Year's Day", **style)
ax.text('2012-7-4', 4250, "Independence Day", ha='center', **style)
ax.text('2012-9-4', 4850, "Labor Day", ha='center', **style)
ax.text('2012-10-31', 4600, "Halloween", ha='right', **style)
ax.text('2012-11-25', 4450, "Thanksgiving", ha='center', **style)
ax.text('2012-12-25', 3850, "Christmas ", ha='right', **style)
# Label the axes
ax.set(title='USA births by day of year (1969-1988)',
ylabel='average daily births')
# Format the x axis with centered month labels: major ticks at month
# starts carry no label, and the month name is drawn on a minor tick at
# the 15th of each month
# NOTE(review): '%h' is a platform strftime code (alias of '%b' on POSIX)
# — confirm it is supported where this runs
ax.xaxis.set_major_locator(mpl.dates.MonthLocator())
ax.xaxis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter('%h'));
fig
Matplotlib has a well-developed set of tools that it uses internally to perform coordinate transforms (the tools can be explored in the matplotlib.transforms submodule). The average user rarely needs to worry about the details of these transforms.
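There are three predefined transforms: ax.transData (data coordinates), ax.transAxes (axes coordinates), and fig.transFigure (figure coordinates).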
# Comparing Matplotlib’s coordinate systems: the same kind of label is
# placed using data, axes and figure coordinates in turn
fig, ax = plt.subplots(facecolor='lightgray')
ax.axis([0, 10, 0, 10])
# transform=ax.transData is the default, but we'll specify it anyway
ax.text(1, 5, ". Data: (1, 5)", transform=ax.transData)
ax.text(0.5, 0.1, ". Axes: (0.5, 0.1)", transform=ax.transAxes)
ax.text(0.2, 0.2, ". Figure: (0.2, 0.2)", transform=fig.transFigure);
# Comparing Matplotlib’s coordinate systems
# If we change the axes limits, it is only the transData
# coordinates that will be affected, while the others remain stationary
ax.set_xlim(0, 2)
ax.set_ylim(-6, 6)
fig
# Arrows and annotation: ax.annotate draws text at xytext with an arrow
# pointing to xy
fig, ax = plt.subplots()
x = np.linspace(0, 20, 1000)
ax.plot(x, np.cos(x))
ax.axis('equal')
ax.annotate('local maximum', xy=(6.28, 1), xytext=(10, 4),
arrowprops=dict(facecolor='black', shrink=0.05))
ax.annotate('local minimum', xy=(5 * np.pi, -1), xytext=(2, -6),
arrowprops=dict(facecolor='black', width=1.5,
connectionstyle="angle3,angleA=0,angleB=-90"));
# The following cell shows some of the arrowprops possibilities
# Annotated average birth rates by day
fig, ax = plt.subplots(figsize=(12, 4))
births_by_date.plot(ax=ax)
# Add labels to the plot. Each annotation points at a data coordinate
# (xycoords='data'); where textcoords='offset points' is used, xytext is
# an offset in points from that location.
ax.annotate("New Year's Day", xy=('2012-1-1', 4100), xycoords='data',
xytext=(50, -30), textcoords='offset points',
arrowprops=dict(arrowstyle="->",
connectionstyle="arc3,rad=-0.2"))
ax.annotate("Independence Day", xy=('2012-7-4', 4250), xycoords='data',
bbox=dict(boxstyle="round", fc="none", ec="gray"),
xytext=(10, -40), textcoords='offset points', ha='center',
arrowprops=dict(arrowstyle="->"))
# Labor Day: a text-only annotation, followed by an annotation with empty
# text that draws only a '|-|' bar between September 1 and September 7
ax.annotate('Labor Day', xy=('2012-9-4', 4850), xycoords='data', ha='center',
xytext=(0, -20), textcoords='offset points')
ax.annotate('', xy=('2012-9-1', 4850), xytext=('2012-9-7', 4850),
xycoords='data', textcoords='data',
arrowprops={'arrowstyle': '|-|,widthA=0.2,widthB=0.2', })
ax.annotate('Halloween', xy=('2012-10-31', 4600), xycoords='data',
xytext=(-80, -40), textcoords='offset points',
arrowprops=dict(arrowstyle="fancy",
fc="0.6", ec="none",
connectionstyle="angle3,angleA=0,angleB=-90"))
ax.annotate('Thanksgiving', xy=('2012-11-25', 4500), xycoords='data',
xytext=(-120, -60), textcoords='offset points',
bbox=dict(boxstyle="round4,pad=.5", fc="0.9"),
arrowprops=dict(arrowstyle="->",
connectionstyle="angle,angleA=0,angleB=80,rad=20"))
ax.annotate('Christmas', xy=('2012-12-25', 3850), xycoords='data',
xytext=(-30, 0), textcoords='offset points',
size=13, ha='right', va="center",
bbox=dict(boxstyle="round", alpha=0.1),
arrowprops=dict(arrowstyle="wedge,tail_width=0.5", alpha=0.1));
# Label the axes
ax.set(title='USA births by day of year (1969-1988)',
ylabel='average daily births')
# Format the x axis with centered month labels: unlabeled major ticks at
# month starts, month names on minor ticks at the 15th of each month
# NOTE(review): '%h' is a platform strftime code (alias of '%b' on POSIX)
# — confirm it is supported where this runs
ax.xaxis.set_major_locator(mpl.dates.MonthLocator())
ax.xaxis.set_minor_locator(mpl.dates.MonthLocator(bymonthday=15))
ax.xaxis.set_major_formatter(plt.NullFormatter())
ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter('%h'));
ax.set_ylim(3600, 5400);
# Example of logarithmic scales and labels
ax = plt.axes(xscale='log', yscale='log')
# Print the locator and formatter objects attached to the x axis
print(ax.xaxis.get_major_locator())
print(ax.xaxis.get_minor_locator())
print(ax.xaxis.get_major_formatter())
print(ax.xaxis.get_minor_formatter())
# Both major and minor tick labels have their
# locations specified by a LogLocator (which makes sense
# for a logarithmic plot)
Perhaps the most common tick/label formatting operation is the act of hiding ticks or labels. We can do this using plt.NullLocator() and plt.NullFormatter()
# Plot with hidden tick labels (x-axis) and hidden ticks (y-axis)
ax = plt.axes()
ax.plot(np.random.rand(50))
# NullLocator removes the y ticks (and with them their labels);
# NullFormatter keeps the x ticks but removes their labels
ax.yaxis.set_major_locator(plt.NullLocator())
ax.xaxis.set_major_formatter(plt.NullFormatter())
# Hiding ticks within image plots
# e.g., display images
fig, ax = plt.subplots(5, 5, figsize=(5, 5))
fig.subplots_adjust(hspace=0, wspace=0)

# Get some face data from scikit-learn
from sklearn.datasets import fetch_olivetti_faces
faces = fetch_olivetti_faces().images

for i in range(5):
    for j in range(5):
        # Ticks carry no information for an image, so remove them
        ax[i, j].xaxis.set_major_locator(plt.NullLocator())
        ax[i, j].yaxis.set_major_locator(plt.NullLocator())
        # Row i shows images 10*i .. 10*i + 4 of the dataset
        ax[i, j].imshow(faces[10 * i + j], cmap="bone")
# A default plot with ticks
fig, ax = plt.subplots(4, 4, sharex=True, sharey=True)
# Customizing the number of ticks
# plt.MaxNLocator(): maximum number of ticks that will be displayed.
# For every axis, set the x and y major locator
for axi in ax.flat:
    axi.xaxis.set_major_locator(plt.MaxNLocator(3))
    axi.yaxis.set_major_locator(plt.MaxNLocator(3))
fig
Matplotlib’s default tick formatting can leave a lot to be desired; it works well as a broad default, but sometimes you’d like to do something more.
# A default plot with integer ticks
# Plot a sine and cosine curve
fig, ax = plt.subplots()
x = np.linspace(0, 3 * np.pi, 1000)
ax.plot(x, np.sin(x), lw=3, label='Sine')
ax.plot(x, np.cos(x), lw=3, label='Cosine')
# Set up grid, legend, and limits
ax.grid(True)
ax.legend(frameon=False)
ax.axis('equal')
ax.set_xlim(0, 3 * np.pi);
# Ticks at multiples of pi/2 (major) and pi/4 (minor)
# MultipleLocator(): locates ticks at a multiple of the number you provide.
ax.xaxis.set_major_locator(plt.MultipleLocator(np.pi / 2))
ax.xaxis.set_minor_locator(plt.MultipleLocator(np.pi / 4))
fig
# Ticks with custom labels
def format_func(value, tick_number):
    """Return a mathtext label for a tick at ``value`` as a multiple of pi/2.

    ``value`` is rounded to the nearest multiple of pi/2. Zero is labeled
    "0"; a coefficient of 1 is omitted (``$\\pi/2$``, ``$\\pi$``), also for
    negative values (``$-\\pi/2$``, ``$-\\pi$``). ``tick_number`` is
    accepted because the function is passed to plt.FuncFormatter, but it
    is not used.
    """
    # find number of multiples of pi/2
    N = int(np.round(2 * value / np.pi))
    if N == 0:
        return "0"
    sign = "-" if N < 0 else ""
    magnitude = abs(N)
    if magnitude % 2:
        # odd multiple of pi/2
        coeff = "" if magnitude == 1 else str(magnitude)
        return r"${0}{1}\pi/2$".format(sign, coeff)
    # even multiple of pi/2, i.e. a whole multiple of pi
    whole = magnitude // 2
    coeff = "" if whole == 1 else str(whole)
    return r"${0}{1}\pi$".format(sign, coeff)
# Use format_func to generate the major tick labels of the x axis, then
# re-display the figure
ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
fig
# Locator class --- Description
# ------------------------------------------------------------------------------------
# NullLocator --- No ticks
# FixedLocator --- Tick locations are fixed
# IndexLocator --- Locator for index plots (e.g., where x = range(len(y)))
# LinearLocator --- Evenly spaced ticks from min to max
# LogLocator --- Logarithmically spaced ticks from min to max
# MultipleLocator --- Ticks and range are a multiple of base
# MaxNLocator --- Finds up to a max number of ticks at nice locations
# AutoLocator --- (Default) MaxNLocator with simple defaults
# AutoMinorLocator --- Locator for minor ticks
# NullFormatter --- No labels on the ticks
# IndexFormatter --- Set the strings from a list of labels
# FixedFormatter --- Set the strings manually for the labels
# FuncFormatter --- User-defined function sets the labels
# FormatStrFormatter --- Use a format string for each value
# ScalarFormatter --- (Default) Formatter for scalar values
# LogFormatter --- Default formatter for log axes
# A histogram in Matplotlib’s default style
# (the 'classic' stylesheet is selected for the cells that follow)
plt.style.use('classic')
%matplotlib inline
x = np.random.randn(1000)
plt.hist(x);
# A histogram with manual customizations
# use a gray background
ax = plt.axes(facecolor='#E6E6E6')
ax.set_axisbelow(True)  # draw the grid underneath the bars

# draw solid white grid lines
plt.grid(color='w', linestyle='solid')

# hide axis spines
for spine in ax.spines.values():
    spine.set_visible(False)

# hide top and right ticks
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

# lighten ticks and labels
ax.tick_params(colors='gray', direction='out')
for tick in ax.get_xticklabels():
    tick.set_color('gray')
for tick in ax.get_yticklabels():
    tick.set_color('gray')

# control face and edge color of histogram
ax.hist(x, edgecolor='#E6E6E6', color='#EE6666');
You can adjust this configuration at any time using the plt.rc convenience routine.
# We’ll start by saving a copy of the current rcParams dictionary,
# so we can easily reset these changes in the current session:
IPython_default = plt.rcParams.copy()
# Change settings: plt.rc(group, **settings) changes the defaults used by
# every plot created afterwards. These settings correspond to the manual
# styling applied to the histogram in the previous cell.
from matplotlib import cycler
colors = cycler('color',
['#EE6666', '#3388BB', '#9988DD',
'#EECC55', '#88BB44', '#FFBBBB'])
plt.rc('axes', facecolor='#E6E6E6', edgecolor='none',
axisbelow=True, grid=True, prop_cycle=colors)
plt.rc('grid', color='w', linestyle='solid')
plt.rc('xtick', direction='out', color='gray')
plt.rc('ytick', direction='out', color='gray')
plt.rc('patch', edgecolor='#E6E6E6')
plt.rc('lines', linewidth=2)
# A customized histogram using rc settings
plt.hist(x);
# A line plot with customized styles: each line takes the next color
# from the prop_cycle configured through plt.rc
for i in range(4):
    plt.plot(np.random.rand(10))
Matplotlib's style module provides a number of stylesheets. These stylesheets are formatted similarly to the .matplotlibrc files mentioned earlier, but must be named with a .mplstyle extension.
# Show the names of the first five available styles
plt.style.available[:5]
# Change a style temporarily with the style context manager:
# with plt.style.context('stylename'):
# make_a_plot()
def hist_and_lines():
    """Draw a histogram and three line plots side by side in a new figure.

    NumPy's global random generator is reseeded with 0 on every call, so
    each call plots the same data.
    """
    np.random.seed(0)
    fig, ax = plt.subplots(1, 2, figsize=(11, 4))
    ax[0].hist(np.random.randn(1000))
    for i in range(3):
        ax[1].plot(np.random.rand(10))
    ax[1].legend(['a', 'b', 'c'], loc='lower left')
# reset rcParams to the values saved before the plt.rc changes
plt.rcParams.update(IPython_default);

# Matplotlib’s default style
hist_and_lines()

# The same figure under several stylesheets; each style is active only
# inside its `with` block.
with plt.style.context('fivethirtyeight'):
    hist_and_lines()

with plt.style.context('ggplot'):
    hist_and_lines()

with plt.style.context('bmh'):
    hist_and_lines()

with plt.style.context('dark_background'):
    hist_and_lines()
The grayscale style is useful, e.g., when preparing figures for a publication that does not accept color figures.
with plt.style.context('grayscale'):
    hist_and_lines()

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and
# later releases dropped the old names, so use whichever one is available.
seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
with plt.style.context(seaborn_style):
    hist_and_lines()
# Set back to the initial style sheet.
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names, so use whichever one is available.
if 'seaborn-v0_8-whitegrid' in plt.style.available:
    plt.style.use('seaborn-v0_8-whitegrid')
else:
    plt.style.use('seaborn-whitegrid')
%matplotlib inline
# Turn the 3d plots interactive
#%matplotlib notebook
#%matplotlib notebook
#import matplotlib as mpl
#import matplotlib.pyplot as plt
#%matplotlib notebook
#%matplotlib notebook
# We enable three-dimensional plots by importing the mplot3d
# toolkit, included with the main Matplotlib installation
from mpl_toolkits import mplot3d
# An empty three-dimensional axes, requested with projection='3d'
fig = plt.figure()
ax = plt.axes(projection='3d')
# Recall that to use interactive figures, you can use %matplotlib notebook
# rather than %matplotlib inline when running this code.
The most basic three-dimensional plot is a line or scatter plot created from sets of (x, y, z) triples. These are created with the ax.plot3D and ax.scatter3D functions.
# Points and lines in three dimensions
ax = plt.axes(projection='3d')
# Data for a three-dimensional line: (sin z, cos z, z)
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')
# Data for three-dimensional scattered points: random heights along the
# same curve, with Gaussian noise (scale 0.1) added to x and y;
# the points are colored by their z value
zdata = 15 * np.random.random(100)
xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='Greens');
def f(x, y):
    """Return sin(r), where r = sqrt(x**2 + y**2) is the distance from the origin.

    Works element-wise on scalars or NumPy arrays of matching (broadcastable)
    shape, so it can be evaluated directly on ``np.meshgrid`` output.

    ``np.hypot`` is used instead of squaring explicitly so that very large
    or very small inputs do not overflow/underflow in the intermediate
    ``x ** 2 + y ** 2``.
    """
    return np.sin(np.hypot(x, y))
# Sample f on a regular 30 x 30 grid covering [-6, 6] in both directions
x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)
# meshgrid expands the two 1-D coordinate vectors into 2-D coordinate arrays
X, Y = np.meshgrid(x, y)
print(type(x), type(X))
Z=f(X,Y)
# A three-dimensional contour plot
fig = plt.figure()
ax = plt.axes(projection='3d')
# Draw 50 contour levels using the 'binary' colormap
ax.contour3D(X, Y, Z, 50, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
# Adjusting the view angle for a three-dimensional plot
# An elevation of 60 degrees (i.e., 60 degrees above the x-y plane)
# An azimuth of 35 degrees (i.e., rotated 35 degrees counter-clockwise about the z-axis)
ax.view_init(60, 35)
# Bare expression: re-displays the figure when run as a notebook cell
fig
# A wireframe plot
# These take a grid of values and project it onto the
# specified three-dimensional surface.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_wireframe(X, Y, Z, color='black')
ax.set_title('wireframe');
# A three-dimensional surface plot
# A surface plot is like a wireframe plot, but each face
# of the wireframe is a filled polygon.
ax = plt.axes(projection='3d')
# rstride=1 / cstride=1: use every row and column of the grid
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
cmap='viridis', edgecolor='none')
ax.set_title('surface');
# A polar surface plot
# Example of creating a partial polar grid, which when used
# with the surface3D plot can give us a slice into the function
# we're visualizing
r = np.linspace(0, 6, 20)
# The angle range stops short of a full circle, leaving a wedge open
theta = np.linspace(-0.9 * np.pi, 0.8 * np.pi, 40)
r, theta = np.meshgrid(r, theta)
# Convert the polar grid to Cartesian coordinates before evaluating f
X = r * np.sin(theta)
Y = r * np.cos(theta)
Z=f(X,Y)
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
cmap='viridis', edgecolor='none');
For some applications, the evenly sampled grids required by the preceding routines are overly restrictive and inconvenient. In these situations, the triangulation-based plots can be very useful.
# A three-dimensional sampled surface
# 1000 random points in polar coordinates (angle in [0, 2*pi), radius in [0, 6))
theta = 2 * np.pi * np.random.random(1000)
r = 6 * np.random.random(1000)
# Convert to Cartesian coordinates; np.ravel guarantees flat 1-D arrays
x = np.ravel(r * np.sin(theta))
y = np.ravel(r * np.cos(theta))
z=f(x,y)
ax = plt.axes(projection='3d')
# Color each sampled point by its function value
ax.scatter(x, y, z, c=z, cmap='viridis', linewidth=0.5);
# A triangulated surface plot
# ax.plot_trisurf: creates a surface by first finding a
# set of triangles formed between adjacent points
ax = plt.axes(projection='3d')
ax.plot_trisurf(x, y, z,
cmap='viridis', edgecolor='none');
A Möbius strip is similar to a strip of paper glued into a loop with a half-twist. It has only a single side.
# Parameterization: theta runs once around the loop, w runs across the
# width of the strip
theta = np.linspace(0, 2 * np.pi, 30)
w = np.linspace(-0.25, 0.25, 8)
w, theta = np.meshgrid(w, theta)
# The twist angle is half the loop angle, i.e. a half-twist per full loop
phi = 0.5 * theta
# radius in x-y plane
r=1+w*np.cos(phi)
# Flatten the embedded coordinates to 1-D arrays for plot_trisurf
x = np.ravel(r * np.cos(theta))
y = np.ravel(r * np.sin(theta))
z = np.ravel(w * np.sin(phi))
# triangulate in the underlying parameterization
# (i.e. in the flat (w, theta) plane rather than on the 3-D (x, y, z) points)
from matplotlib.tri import Triangulation
tri = Triangulation(np.ravel(w), np.ravel(theta))
ax = plt.axes(projection='3d')
ax.plot_trisurf(x, y, z, triangles=tri.triangles,
cmap='viridis', linewidths=0.2);
# Use the same limits on all three axes
ax.set_xlim(-1, 1); ax.set_ylim(-1, 1); ax.set_zlim(-1, 1);
from mpl_toolkits.basemap import Basemap
# A "bluemarble" projection of the Earth
plt.figure(figsize=(8, 8))
# Orthographic projection centered on latitude 50, longitude -100
m = Basemap(projection='ortho', resolution=None, lat_0=50, lon_0=-100)
m.bluemarble(scale=0.5);
# Plotting data and labels on the map
fig = plt.figure(figsize=(8, 8))
# 'lcc' map of 8E6 x 8E6 (map projection units) centered on lat 45, lon -100
m = Basemap(projection='lcc', resolution=None,
width=8E6, height=8E6,
lat_0=45, lon_0=-100,)
m.etopo(scale=0.5, alpha=0.5)
# Map (long, lat) to (x, y) for plotting
# Calling the Basemap instance converts (lon, lat) to map coordinates
x, y = m(-122.3, 47.6)
# 'ok' = black circle marker
plt.plot(x, y, 'ok', markersize=5)
plt.text(x, y, ' Seattle', fontsize=12);
from itertools import chain
def draw_map(m, scale=0.2):
    """Draw a shaded-relief background with faint latitude/longitude lines.

    Parameters
    ----------
    m : map object
        Must provide ``shadedrelief``, ``drawparallels`` and
        ``drawmeridians`` (a Basemap instance in this notebook).
    scale : float, optional
        Passed through to ``m.shadedrelief`` (default 0.2).
    """
    # draw a shaded-relief image
    m.shadedrelief(scale=scale)
    # Parallels and meridians every 15 degrees; each call returns a dictionary
    lats = m.drawparallels(np.linspace(-90, 90, 13))
    lons = m.drawmeridians(np.linspace(-180, 180, 13))
    # Each dictionary *value* is a tuple whose first element is the list of
    # line artists; the keys themselves are not needed here.
    grid_lines = chain.from_iterable(
        entry[0] for entry in chain(lats.values(), lons.values()))
    # cycle through these lines and set the desired style
    for line in grid_lines:
        line.set(linestyle='-', alpha=0.3, color='w')
# Cylindrical projections
# (projection='cyl' is used below)
# latitude (lat) and longitude (lon)
# lower-left corner (llcrnr) and upper-right corner (urcrnr)
# in units of degrees.
fig = plt.figure(figsize=(8, 6), edgecolor='w')
m = Basemap(projection='cyl', resolution=None,
llcrnrlat=-90, urcrnrlat=90,
llcrnrlon=-180, urcrnrlon=180, )
draw_map(m)
# Pseudo-cylindrical projections
# The Mollweide projection
# The Mollweide projection (projection='moll') is one common
# example of this, in which all meridians are elliptical arcs
fig = plt.figure(figsize=(8, 6), edgecolor='w')
m = Basemap(projection='moll', resolution=None,
lat_0=0, lon_0=0)
draw_map(m)
# Perspective projections
## The orthographic projection
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='ortho', resolution=None,
lat_0=50, lon_0=0)
draw_map(m);
# Conic projections
## The Lambert conformal conic projection (projection='lcc')
## lat_1 and lat_2 are the two standard parallels
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution=None,
lon_0=0, lat_0=50, lat_1=45, lat_2=55,
width=1.6E7, height=1.2E7)
draw_map(m)
Example: California Cities
# Scatter plot over a map background
cities = pd.read_csv('data/california_cities.csv')
# Extract the data we're interested in
lat = cities['latd'].values
lon = cities['longd'].values
population = cities['population_total'].values
area = cities['area_total_km2'].values
# 1. Draw the map background
fig = plt.figure(figsize=(8, 8))
m = Basemap(projection='lcc', resolution='h',
            lat_0=37.5, lon_0=-119,
            width=1E6, height=1.2E6)
m.shadedrelief()
m.drawcoastlines(color='gray')
m.drawcountries(color='gray')
m.drawstates(color='gray')
# 2. scatter city data, with color reflecting population
# and size reflecting area
# latlon=True: the inputs are in degrees and are converted by Basemap
m.scatter(lon, lat, latlon=True,
          c=np.log10(population), s=area,
          cmap='Reds', alpha=0.5)
# 3. create colorbar and legend
plt.colorbar(label=r'$\log_{10}({\rm population})$')
# Clip the color scale to 10^3 .. 10^7 inhabitants
plt.clim(3, 7)
# make legend with dummy points
# (empty scatters that only contribute a sized marker and label to the legend)
for a in [100, 300, 500]:
    plt.scatter([], [], c='k', alpha=0.5, s=a,
                label=str(a) + ' km$^2$')
plt.legend(scatterpoints=1, frameon=False,
           labelspacing=1, loc='lower left');
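There are several valid complaints about Matplotlib that often come up: its default styles are dated, its API is relatively low-level, and it was not designed for use with pandas DataFrames.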
An answer to these problems is Seaborn.
# Create some data
# Six random walks: cumulative sums (along axis 0) of 500 x 6 normal samples
rng = np.random.RandomState(0)
x = np.linspace(0, 10, 500)
y = np.cumsum(rng.randn(500, 6), 0)
# Plot the data with Matplotlib defaults
# Data in Matplotlib's default style
plt.plot(x, y)
# The string 'ABCDEF' supplies one single-letter label per line
plt.legend('ABCDEF', ncol=2, loc='upper left');
import seaborn as sns
sns.set()
# Seaborn can also overwrite Matplotlib's default parameters
# and in turn get even simple Matplotlib scripts to produce
# vastly superior output.
# same plotting code as above!
plt.plot(x, y)
plt.legend('ABCDEF', ncol=2, loc='upper left');
# Histograms for visualizing distributions
# 2000 samples from a correlated two-dimensional normal distribution
data = np.random.multivariate_normal([0, 0], [[5, 2], [2, 2]], size=2000)
data = pd.DataFrame(data, columns=['x', 'y'])
for col in 'xy':
    plt.hist(data[col], density=True, alpha=0.5)
# Kernel density estimates for visualizing distributions
# NOTE(review): newer seaborn releases rename ``shade`` to ``fill`` and
# deprecate ``distplot`` -- confirm the seaborn version this notebook targets
# before changing these calls.
for col in 'xy':
    sns.kdeplot(data[col], shade=True)
# Kernel density estimates (KDE) and histograms plotted together
sns.distplot(data['x'])
sns.distplot(data['y']);
# A two-dimensional kernel density plot
# If we pass two vectors to kdeplot,
# we will get a two-dimensional visualization of the data
sns.kdeplot(data['x'], data['y']);
# A joint distribution plot with a two-dimensional kernel density estimate
# We can see the joint distribution and the marginal
# distributions together using sns.jointplot.
# x, y and data are passed by keyword: positional use is rejected by newer
# seaborn releases, while keywords work on old and new versions alike.
with sns.axes_style('white'):
    sns.jointplot(x="x", y="y", data=data, kind='kde');
# A joint distribution plot with a hexagonal bin representation
with sns.axes_style('white'):
    sns.jointplot(x="x", y="y", data=data, kind='hex')
# A pair plot showing the relationships between four variables
iris = sns.load_dataset("iris")
sns.pairplot(iris, hue='species', height=2.5);
Sometimes the best way to view data is via histograms of subsets. Seaborn’s FacetGrid makes this extremely simple. We’ll take a look at some data that shows the amount that restaurant staff receive in tips based on various indicator data.
# An example of a faceted histogram
tips = sns.load_dataset('tips')
# Tip expressed as a percentage of the total bill
tips['tip_pct'] = 100 * tips['tip'] / tips['total_bill']
# One panel per (sex, time) combination: rows by sex, columns by time
grid = sns.FacetGrid(tips, row="sex", col="time", margin_titles=True)
# Draw a histogram of tip_pct in every panel, using 15 bin edges from 0 to 40
grid.map(plt.hist, "tip_pct", bins=np.linspace(0, 40, 15));
Factor plots (sns.catplot, formerly sns.factorplot) can be useful for this kind of visualization as well. They allow you to view the distribution of a parameter within bins defined by any other parameter.
# An example of a factor plot, comparing distributions given
# various discrete factors
# Column names are passed by keyword (x / y / hue): positional use is
# rejected by newer seaborn releases, while keywords work on all versions.
with sns.axes_style(style='ticks'):
    g = sns.catplot(x="day", y="total_bill", hue="sex", data=tips, kind="box")
    g.set_axis_labels("Day", "Total Bill");
# A joint distribution plot
with sns.axes_style('white'):
    sns.jointplot(x="total_bill", y="tip", data=tips, kind='hex')
# The joint plot can even do some automatic kernel density estimation and regression
sns.jointplot(x="total_bill", y="tip", data=tips, kind='reg');
Time series can be plotted with sns.catplot (formerly sns.factorplot).
planets = sns.load_dataset('planets')
planets.head(2)
# A histogram as a special case of a factor plot
# ``x`` is passed by keyword so the call works on old and new seaborn alike.
with sns.axes_style('white'):
    g = sns.catplot(x="year", data=planets, aspect=2,
                    kind="count", color='steelblue')
    # Label only every fifth year to keep the axis readable
    g.set_xticklabels(step=5)
#### Number of planets discovered by year and type
with sns.axes_style('white'):
    g = sns.catplot(x="year", data=planets, aspect=4.0, kind='count',
                    hue='method', order=range(2001, 2015))
    g.set_ylabels('Number of Planets Discovered')
# Load the marathon results; head() and dtypes are notebook inspections
data = pd.read_csv('data/marathon-data.csv')
data.head()
data.dtypes
# pandas has no time-of-day dtype without a date, so the 'H:M:S' strings are
# parsed into full datetimes; only the time components are used afterwards.
# NOTE(review): the parsed values presumably carry a default placeholder
# date -- confirm it is never relied upon.
data['split'] = pd.to_datetime(data['split'], format='%H:%M:%S')
data['final'] = pd.to_datetime(data['final'], format='%H:%M:%S')
data.head()
data.dtypes
def convert_to_seconds(time):
    """Convert a datetime Series to elapsed seconds since midnight.

    Parameters
    ----------
    time : pandas.Series
        Series with a datetime dtype (it must support the ``.dt`` accessor).
        Only the hour, minute and second components are used; the date part
        and any sub-second part are ignored.

    Returns
    -------
    pandas.Series
        Integer seconds, ``hour * 3600 + minute * 60 + second``, aligned with
        the input index.
    """
    clock = time.dt
    return clock.hour * 3600 + clock.minute * 60 + clock.second
data['split_sec'] = convert_to_seconds(data['split'])
data['final_sec'] = convert_to_seconds(data['final'])
data.head()
# x, y and data are passed by keyword so the call works on old and new
# seaborn releases alike.
with sns.axes_style('white'):
    g = sns.jointplot(x="split_sec", y="final_sec", data=data, kind='hex')
    # Dotted reference line final = 2 * split, i.e. perfectly even halves
    g.ax_joint.plot(np.linspace(4000, 16000),
                    np.linspace(8000, 32000), ':k')
# The relationship between the split for the first half-marathon
# and the finishing time for the full marathon
# Let's create another column in the data, the split fraction,
# which measures the degree to which each runner negative-splits
# or positive-splits the race:
data['split_frac'] = 1 - 2 * data['split_sec'] / data['final_sec']
data.head()
# The distribution of split fractions; 0.0 indicates a runner
# who completed the first and second halves in identical times
sns.distplot(data['split_frac'], kde=False);
plt.axvline(0, color="k", linestyle="--");
# Out of nearly 40,000 participants, there were only
# 250 people who negative-split their marathon.
# Vectorized count of True values (instead of Python's builtin sum iterating
# over the Series element by element).
(data.split_frac < 0).sum()
# The relationship between quantities within the marathon dataset
# Check if there are any correlations
# One scatter panel per pair of the four listed columns, colored by gender
g = sns.PairGrid(data, vars=['age', 'split_sec', 'final_sec', 'split_frac'],
hue='gender', palette='RdBu_r')
# Low alpha keeps dense regions readable despite heavy overplotting
g.map(plt.scatter, alpha=0.1)
g.add_legend();
# The distribution of split fractions by gender
sns.kdeplot(data.split_frac[data.gender=='M'], label='men', shade=True)
sns.kdeplot(data.split_frac[data.gender=='W'], label='women', shade=True)
plt.xlabel('split_frac');
# A violin plot showing the split fraction by gender
# x and y are passed by keyword so the calls work on old and new seaborn
# releases alike.
sns.violinplot(x="gender", y="split_frac", data=data,
               palette=["lightblue", "lightpink"]);
# A new column in the array that specifies the decade of age that each person is in
data['age_dec'] = data.age.map(lambda age: 10 * (age // 10))
data.head()
# Boolean masks selecting each gender
men = (data.gender == 'M')
women = (data.gender == 'W')
# A violin plot showing the split fraction by gender and age
# split=True draws the two genders as the two halves of each violin
with sns.axes_style(style=None):
    sns.violinplot(x="age_dec", y="split_frac",
                   hue="gender", data=data,
                   split=True, inner="quartile",
                   palette=["lightblue", "lightpink"]);
# Also surprisingly, the 80-year-old women seem to outperform
# everyone in terms of their split time. This is probably due
# to the fact that we're estimating the distribution from small
# numbers, as there are only a handful of runners in that range:
(data.age > 80).sum()
# Split fraction versus finishing time by gender
# lmplot automatically fits a linear regression to the data
# x and y are passed by keyword so the call works on old and new seaborn
# releases alike.
g = sns.lmplot(x='final_sec', y='split_frac', col='gender', data=data,
               markers=".", scatter_kws=dict(color='c'))
# Dotted horizontal reference line at split_frac = 0.1 in every panel
g.map(plt.axhline, y=0.1, color="k", ls=":");